{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Few-shot Text Classification\n",
    "\n",
    "# Table of contents \n",
    "\n",
    "1. [Introduction to the task](#1.-Introduction-to-the-task)\n",
    "\n",
    "2. [Get started with the model](#2.-Get-started-with-the-model)\n",
    "\n",
    "3. [Models list](#3.-Models-list)\n",
    "\n",
    "4. [Use the model for prediction](#4.-Use-the-model-for-prediction)\n",
    "\n",
    "    4.1 [Dataset format](#4.1-Dataset-format)\n",
    "\n",
    "    4.2. [Predict using Python](#4.2-Predict-using-Python)\n",
    "    \n",
    "    4.3. [Predict using CLI](#4.3-Predict-using-CLI)\n",
    "\n",
    "5. [Customize the model](#5.-Customize-the-model)\n",
    "\n",
    "# 1. Introduction to the task\n",
    "\n",
    "__Text classification__ is a task of identifying one of the pre-defined label given an utterance, where label is one of N classes or \"OOS\" (out-of-scope examples - utterances that do not belong to any of the predefined classes). We consider few-shot setting, where only few examples (5 or 10) per intent class are given as a training set.\n",
    "\n",
    "\n",
    "# 2. Get started with the model\n",
    "\n",
    "First make sure you have the DeepPavlov Library installed.\n",
    "[More info about the first installation.](http://docs.deeppavlov.ai/en/master/intro/installation.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q deeppavlov"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then make sure that all the required packages are installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m deeppavlov install few_shot_roberta"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`few_shot_roberta` is the name of the model's *config_file*. [What is a Config File?](http://docs.deeppavlov.ai/en/master/intro/configuration.html) \n",
    "\n",
    "Configuration file defines the model and describes its hyperparameters. To use another model, change the name of the *config_file* here and further.\n",
    "Some of few-shot classification models with their config names can be found in the [table](#3.-Models-list).\n",
    "\n",
    "# 3. Models list\n",
    "\n",
    "At the moment, only `few_shot_roberta` config support out-of-scope detection.\n",
    "\n",
    "| Config name  | Dataset | Shot | Model Size | In-domain accuracy | Out-of-scope recall | Out-of-scope precision |\n",
    "| :--- | --- | --- | --- | --- |  --- | ---: |\n",
    "| few_shot_roberta| [CLINC150-Banking-Domain](https://paperswithcode.com/paper/an-evaluation-dataset-for-intent)  | 5 | 1.4 GB | 84.1±1.9 | 93.2±0.8 | 97.8±0.3 |\n",
    "| few_shot_roberta| [CLINC150](https://paperswithcode.com/paper/an-evaluation-dataset-for-intent)  | 5 | 1.4 GB | 59.4±1.4 | 87.9±1.2 | 40.3±0.7 |\n",
    "| few_shot_roberta| [BANKING77-OOS](https://paperswithcode.com/paper/are-pretrained-transformers-robust-in-intent)  | 5 | 1.4 GB | 51.4±2.1 | 93.7±0.7 | 82.7±1.4 |\n",
    "| fasttext_logreg*| [CLINC150-Banking-Domain](https://paperswithcode.com/paper/an-evaluation-dataset-for-intent)  | 5 | 37 KB |24.8±2.2 | 98.2±0.4 | 74.8±0.6 |\n",
    "| fasttext_logreg*| [CLINC150](https://paperswithcode.com/paper/an-evaluation-dataset-for-intent)  | 5 | 37 KB | 13.4±0.5 | 98.6±0.2 | 20.5±0.1 |\n",
    "| fasttext_logreg*| [BANKING77-OOS](https://paperswithcode.com/paper/are-pretrained-transformers-robust-in-intent)  | 5 | 37 KB |10.7±0.8 | 99.0±0.3 | 36.4±0.2 |\n",
    "\n",
    "\n",
    "With zero threshold we can get a classification accuracy without OOS detection:\n",
    "\n",
    "| Config name  | Dataset | Shot | Model Size | Accuracy |\n",
    "| :--- | --- | --- | --- | ---: |\n",
    "| few_shot_roberta| [CLINC150-Banking-Domain](https://paperswithcode.com/paper/an-evaluation-dataset-for-intent)  | 5 | 1.4 GB | 89.6 |\n",
    "| few_shot_roberta| [CLINC150](https://paperswithcode.com/paper/an-evaluation-dataset-for-intent)  | 5 | 1.4 GB | 79.6 |\n",
    "| few_shot_roberta| [BANKING77-OOS](https://paperswithcode.com/paper/are-pretrained-transformers-robust-in-intent)  | 5 | 1.4 GB | 55.1 |\n",
    "| fasttext_logreg*| [CLINC150-Banking-Domain](https://paperswithcode.com/paper/an-evaluation-dataset-for-intent)  | 5 | 37 KB | 86.3 |\n",
    "| fasttext_logreg*| [CLINC150](https://paperswithcode.com/paper/an-evaluation-dataset-for-intent)  | 5 | 37 KB | 73.6\n",
    "| fasttext_logreg*| [BANKING77-OOS](https://paperswithcode.com/paper/are-pretrained-transformers-robust-in-intent)  | 5 | 37 KB | 51.6 |\n",
    "\n",
    "\\* \\- config file was modified to predict OOS examples\n",
    "\n",
    "\n",
    "# 4. Use the model for prediction\n",
    "\n",
    "Base model `few_shot_roberta` was already pre-trained to recognize simmilar utterances, so you can use off-the-shelf model to make predictions and evalutation. No additional training needed.\n",
    "\n",
    "## 4.1 Dataset format\n",
    "\n",
    "DNNC model compares input text to every example in dataset to determine, which class the input example belongs to. The dataset based on which classification is performed has the following format:\n",
    "\n",
    "```\n",
    "[\n",
    "    [\"text_1\",  \"label_1\"],\n",
    "    [\"text_2\",  \"label_2\"],\n",
    "             ...\n",
    "    [\"text_n\",  \"label_n\"]\n",
    "]\n",
    "```\n",
    "\n",
    "## 4.2 Predict using Python\n",
    "\n",
    "After [installing](#2.-Get-started-with-the-model) the model, build it from the config and predict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deeppavlov import build_model\n",
    "\n",
    "model = build_model(\"few_shot_roberta\", download=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you set `download` flag to `True`, then existing model weights will be overwritten.\n",
    "\n",
    "Setting the `install` argument to `True` is equivalent to executing the command line `install` command. If set to `True`, it will first install all the required packages.\n",
    "\n",
    "**Input**: List[texts, dataset]\n",
    "\n",
    "**Output**: List[labels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['translate', 'exchange_rate', 'car_rental']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "texts = [\n",
    "    \"what expression would i use to say i love you if i were an italian\",\n",
    "    \"what's the currency conversion between krones and yen\",\n",
    "    \"i'd like to reserve a high-end car\"\n",
    "]\n",
    "\n",
    "dataset = [\n",
    "    [\"please help me book a rental car for nashville\",                       \"car_rental\"],\n",
    "    [\"how can i rent a car in boston\",                                       \"car_rental\"],\n",
    "    [\"help me get a rental car for march 2 to 6th\",                          \"car_rental\"],\n",
    "    \n",
    "    [\"how many pesos can i get for one dollar\",                              \"exchange_rate\"],\n",
    "    [\"tell me the exchange rate between rubles and dollars\",                 \"exchange_rate\"],\n",
    "    [\"what is the exchange rate in pesos for 100 dollars\",                   \"exchange_rate\"],\n",
    "    \n",
    "    [\"can you tell me how to say 'i do not speak much spanish', in spanish\", \"translate\"],\n",
    "    [\"please tell me how to ask for a taxi in french\",                       \"translate\"],\n",
    "    [\"how would i say thank you if i were russian\",                          \"translate\"]\n",
    "]\n",
    "\n",
    "model(texts, dataset)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.3 Predict using CLI\n",
    "\n",
    "You can also get predictions in an interactive mode through CLI (Сommand Line Interface)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m deeppavlov interact few_shot_roberta -d"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`-d` is an optional download key (alternative to `download=True` in Python code). The key `-d` is used to download the pre-trained model along with all other files needed to run the model.\n",
    "\n",
    "Or make predictions for samples from *stdin*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m deeppavlov predict few_shot_roberta -f <file-name>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. Customize the model\n",
    "\n",
    "Out-of-scope (OOS) examples are determined via confidence with *confidence_threshold* parameter. For each input text, if the confidence of the model is lower than the *confidence_threshold*, then the input example is considered out-of-scop. The higher the threshold, the more often the model predicts \"oos\" class. By default it is set to 0, but you can change it to your preferences in configuration file."
   ]
  },
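  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The thresholding rule described above can be illustrated with a short sketch. It is not DeepPavlov's internal implementation: `apply_threshold` is a hypothetical helper, and the confidence is assumed to be a non-negative number.\n",
    "\n",
    "```python\n",
    "def apply_threshold(label, confidence, confidence_threshold=0.0, oos_label='oos'):\n",
    "    # Below the threshold the prediction is treated as out-of-scope.\n",
    "    if confidence < confidence_threshold:\n",
    "        return oos_label\n",
    "    return label\n",
    "\n",
    "\n",
    "print(apply_threshold('translate', 0.4, confidence_threshold=0.5))\n",
    "print(apply_threshold('translate', 0.6, confidence_threshold=0.5))\n",
    "```\n",
    "\n",
    "With the default threshold of 0, no prediction is replaced as long as confidences are non-negative."
   ]
  },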
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n"
     ]
    }
   ],
   "source": [
    "from deeppavlov import build_model\n",
    "from deeppavlov.core.commands.utils import parse_config\n",
    "\n",
    "model_config = parse_config('few_shot_roberta')\n",
    "model_config['chainer']['pipe'][-1]['confidence_threshold'] = 0.1\n",
    "model = build_model(model_config)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
